# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Taken from https://raw.githubusercontent.com/facebookresearch/torchbeast/3f3029cf3d6d488b8b8f952964795f451a49048f/torchbeast/monobeast.py
# and modified

import os
import logging
import pprint
import time
import timeit
import traceback
import typing
import copy
import psutil
import numpy as np
import queue
import cloudpickle
from torch.multiprocessing import Pool
import threading
import json
import shutil
import signal

import torch
import multiprocessing as py_mp
from torch import multiprocessing as mp
from torch import nn
from torch.nn import functional as F

from continual_rl.policies.impala.torchbeast.core import environment
from continual_rl.policies.impala.torchbeast.core import prof
from continual_rl.policies.impala.torchbeast.core import vtrace
from continual_rl.utils.utils import Utils

Buffers = typing.Dict[str, typing.List[torch.Tensor]]


class LearnerThreadState():
    """
    Helper that carries the lifecycle state of one learner thread between threads.

    ``state`` is a plain integer attribute. It is assumed to be atomic enough that
    polling it needs no further thread safety; ``lock`` is a ``threading.Lock`` that
    callers can hold while they inspect and update ``state`` in one step.
    """

    # The four learner states: starting, running, stop requested, stopped.
    STARTING, RUNNING, STOP_REQUESTED, STOPPED = range(4)

    def __init__(self):
        self.state = self.STARTING
        self.lock = threading.Lock()

    def wait_for(self, desired_state_list, timeout=300):
        """
        Block until ``state`` is one of ``desired_state_list`` or ``timeout`` seconds pass.

        :param desired_state_list: container of acceptable states.
        :param timeout: maximum time to wait, in seconds.
        :return: True if a desired state was reached, False if the wait timed out.
        """
        poll_interval = 0.1  # Seconds slept between two checks of the state.

        # A monotonic deadline avoids the rounding of repeatedly summed sleep
        # intervals, which could make a real timeout go unreported.
        deadline = time.monotonic() + timeout

        while self.state not in desired_state_list:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                print(
                    f"Gave up on waiting due to timeout. Desired list: {desired_state_list}, current state: {self.state}")  # TODO: not print
                return False
            time.sleep(min(poll_interval, remaining))

        return True


class Monobeast():
    def __init__(self, model_flags, observation_spaces, action_spaces, policy_class):
        """
        Create the models, rollout buffers, loggers and bookkeeping state used by train().

        :param model_flags: configuration object of the policy being trained.
        :param observation_spaces: observation spaces of all tasks; the largest one sizes the frame buffers.
        :param action_spaces: action spaces, indexed by a task's action_space_id.
        :param policy_class: callable that setup() uses to construct the actor and learner networks.
        """
        # model_flags is the configuration object of the corresponding policy.
        self._model_flags = model_flags
        # The unroll length is kept as a member so that it can be changed across processes.
        self._unroll_length = model_flags.unroll_length

        # Dictionary of member-variable overrides that act() applies to itself on every loop.
        param_manager = py_mp.Manager()
        self.shared_params = param_manager.dict()
        # Whether actors currently output random actions.
        self.actor_is_random = False

        # Observations of the latest complete episode generated by the actor with actor_index == 0.
        self._videos_to_log = py_mp.Manager().Queue(maxsize=1)

        # The largest observation space among the observation spaces of all tasks.
        self.max_observation_space = Utils.get_max_observation_space(observation_spaces)
        self.action_spaces = action_spaces

        # Part of the original Monobeast code lives in setup(), which builds these objects.
        setup_result = self.setup(model_flags, observation_spaces, action_spaces, policy_class)
        (self.buffers, self.actor_model, self.learner_model, self.optimizer,
         self.plogger, self.logger, self.checkpointpath) = setup_result

        self._scheduler_state_dict = None  # Filled if we load()
        self._scheduler = None  # Task-specific, so it is created in train()

        # Threads/processes are tracked so that they can be cleaned up.
        self._learner_thread_states, self._actor_processes = [], []

        # train() is called many times (once per task, per cycle). The current assumption is
        # that only one train() runs at a time and every other one has been cleaned up; these
        # two ids help to make sure that holds.
        self._train_loop_id_counter, self._train_loop_id_running = 0, None

        # When a task is reloaded, training has to continue from where it stopped.
        self.last_timestep_returned = 0

        # Created in train() and kept here so that they can be shut down cleanly.
        self.free_queue = self.full_queue = None

        # Pillow sometimes pollutes the logs, see: https://github.com/python-pillow/Pillow/issues/5096
        pillow_logger = logging.getLogger("PIL.PngImagePlugin")
        pillow_logger.setLevel(logging.CRITICAL + 1)

    # Functions designed to be overridden by subclasses of Monobeast
    def on_act_unroll_complete(self, task_flags, actor_index, agent_output, env_output, new_buffers):
        """
        Hook called after every rollout in each process that runs act().

        This happens in a separate process, so any data produced here has to be
        passed back accordingly. The base implementation does nothing.
        """
        return None

    def get_batch_for_training(self, batch):
        """
        Hook that derives the batch actually trained on from the freshly collected one
        (for example by augmenting it with samples from a replay buffer).

        Runs in every learner thread. The base implementation hands back the very
        same batch object.
        """
        training_batch = batch
        return training_batch

    def custom_loss(self, task_flags, model, initial_agent_state, batch, vtrace_returns):
        """
        Hook for an additional loss term, added to the existing loss before back-propagation.

        Every returned stat is merged into the logged stats; a stat whose key ends in
        "_loss" is also plotted automatically. Runs in every learner thread.
        The base implementation contributes no loss and no stats.

        :return: (loss, dict of stats)
        """
        extra_loss = 0
        extra_stats = {}
        return extra_loss, extra_stats

    def permanent_delete(self):
        """Hook for subclasses; the base implementation does nothing and returns None."""
        return None

    # Core Monobeast functionality
    def setup(self, model_flags, observation_spaces, action_spaces, policy_class):
        """
        Build everything the trainer needs before the first train() call.

        Side effects: sets OMP_NUM_THREADS to "1", configures the root logging format,
        fills in model_flags.num_buffers when it is None and replaces model_flags.device
        by a torch.device.

        :param model_flags: policy configuration; read and partially updated as described above.
        :param observation_spaces: forwarded to policy_class.
        :param action_spaces: forwarded to policy_class.
        :param policy_class: callable building a network from (observation_spaces, action_spaces, model_flags).
        :return: (buffers, actor model, learner model, optimizer, plogger, logger, checkpoint path)
        :raises ValueError: on inconsistent buffer counts or an unsupported optimizer name.
        """
        os.environ["OMP_NUM_THREADS"] = "1"
        log_format = "[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] %(message)s"
        logging.basicConfig(format=log_format, level=0)

        logger = Utils.create_logger(os.path.join(model_flags.savedir, "impala_logs.log"))
        plogger = Utils.create_logger(os.path.join(model_flags.savedir, "impala_results.log"))
        checkpointpath = os.path.join(model_flags.savedir, "model.tar")

        # Pick a sensible default for num_buffers: at least twice the number of actors
        # and at least one full batch.
        if model_flags.num_buffers is None:
            model_flags.num_buffers = max(2 * model_flags.num_actors, model_flags.batch_size)
        # There must be strictly more buffers than actors ...
        if model_flags.num_buffers <= model_flags.num_actors:
            raise ValueError("缓冲区数量应该大于演员数量")
        # ... and at least as many buffers as one batch consumes.
        if model_flags.batch_size > model_flags.num_buffers:
            raise ValueError("缓冲区数量应该大于等于batch_size")

        # Convert the device string into an actual device
        model_flags.device = torch.device(model_flags.device)

        # Network used by the actors; its parameters are moved to shared memory.
        model = policy_class(observation_spaces, action_spaces, model_flags)
        buffers = self.create_buffers(model_flags, self.max_observation_space.shape, model.num_actions)
        model.share_memory()

        # Network used by the learner, placed on the configured device.
        learner_model = policy_class(observation_spaces, action_spaces, model_flags)
        learner_model = learner_model.to(device=model_flags.device)

        # Create the optimizer of the learner network.
        optimizer_name = model_flags.optimizer
        if optimizer_name == "rmsprop":
            rmsprop_options = dict(
                lr=model_flags.learning_rate,
                momentum=model_flags.momentum,
                eps=model_flags.epsilon,
                alpha=model_flags.alpha,
            )
            optimizer = torch.optim.RMSprop(learner_model.parameters(), **rmsprop_options)
        elif optimizer_name == "adam":
            optimizer = torch.optim.Adam(learner_model.parameters(), lr=model_flags.learning_rate)
        else:
            raise ValueError(f"Unsupported optimizer type {model_flags.optimizer}.")

        return buffers, model, learner_model, optimizer, plogger, logger, checkpointpath

    def compute_baseline_loss(self, advantages):
        """Return the baseline (value) loss: half of the sum of the squared advantages."""
        squared_advantages = advantages.pow(2)
        return 0.5 * squared_advantages.sum()

    def compute_entropy_loss(self, logits):
        """Return the entropy loss, i.e. the negative entropy of the policy given by ``logits``."""
        log_probabilities = F.log_softmax(logits, dim=-1)
        probabilities = F.softmax(logits, dim=-1)
        return (probabilities * log_probabilities).sum()

    def compute_policy_gradient_loss(self, logits, actions, advantages):
        """
        Return the policy-gradient loss: the negative log-probability of every taken
        action, weighted by its (detached) advantage and summed.

        The first two dimensions of ``logits`` and ``actions`` are flattened together
        for the per-step computation and restored to the shape of ``advantages``.
        """
        flat_log_probs = F.log_softmax(logits.flatten(0, 1), dim=-1)
        flat_actions = actions.flatten(0, 1)
        negative_log_probs = F.nll_loss(flat_log_probs, target=flat_actions, reduction="none")
        weighted = negative_log_probs.view_as(advantages) * advantages.detach()
        return weighted.sum()

    @staticmethod
    def preprocess_env_output(env_output, model_flags, action_output=None):
        # 预处理环境输出函数
        return env_output

    def act(self, model_flags, task_flags, actor_index: int, free_queue: py_mp.Queue, full_queue: py_mp.Queue,
            model: torch.nn.Module, buffers: Buffers, initial_agent_state_buffers):
        """
        Actor loop: step an environment with the model and write the rollouts into the buffers.

        Each iteration takes a buffer index from ``free_queue``, fills that buffer with
        ``self._unroll_length`` steps of environment and agent outputs, calls
        on_act_unroll_complete() and then puts the index on ``full_queue`` so that the
        data can be picked up later. The loop ends when ``None`` is taken from
        ``free_queue``. The environment is closed on the way out, whatever happened.

        :param model_flags: policy configuration, forwarded to preprocess_env_output() and create_buffers().
        :param task_flags: task description; ``env_spec`` and ``action_space_id`` are read here.
        :param actor_index: index of this actor; actor 0 additionally collects episodes for video logging.
        :param free_queue: queue of buffer indices that may be filled.
        :param full_queue: queue onto which the indices of filled buffers are put.
        :param model: network used to compute the agent outputs.
        :param buffers: mapping from key to the list of per-index rollout tensors written here.
        :param initial_agent_state_buffers: per-index storage of the agent state at the start of a rollout.
        """
        env = None
        try:
            self.logger.info("Actor %i 启动.", actor_index)
            timings = prof.Timings()  # Tracks how long each stage takes.
            observations_to_render = []  # Observations kept for visualising behaviour; only filled by actor 0.

            # Create and initialise the environment.
            gym_env, seed = Utils.make_env(task_flags.env_spec, create_seed=True)
            self.logger.info(f"Environment and libraries setup with seed {seed}")
            env = environment.Environment(gym_env)
            env_output = env.initial()
            env_output = self.preprocess_env_output(env_output, model_flags)  # Pre-process the environment output.

            # If a termination signal is delivered, make sure the env is shut down cleanly.
            # NOTE(review): the handler only closes the env; it does not itself end the process.
            def end_task(*args):
                env.close()

            signal.signal(signal.SIGTERM, end_task)  # Run end_task (i.e. close the environment) on SIGTERM.

            # Pad the observation to the common (largest) observation shape and compute the first agent output.
            env_output['frame'] = Utils.padding_state(env_output['frame'], self.max_observation_space.shape)
            agent_state = model.initial_state(batch_size=1)

            # To keep things simple, the very first output never uses a random action.
            agent_output, unused_state = model(env_output, task_flags.action_space_id, agent_state)

            while True:
                # Apply any member-variable override published through shared_params.
                for param_key, param_value in self.shared_params.items():
                    if self.__getattribute__(param_key) != param_value:
                        self.__setattr__(param_key, param_value)
                        self.logger.info(f"Actor {actor_index} 更新成员变量 {param_key} 为 {param_value}")
                        if param_key == "actor_is_random" and param_value == False:
                            # Once random exploration is switched off the buffers are recreated, so that
                            # randomly explored data is not used for training.
                            # NOTE(review): this rebinds self.buffers only; the `buffers` argument that the
                            # writes below go to is unchanged - confirm this has the intended effect.
                            self.buffers = self.create_buffers(model_flags, self.max_observation_space.shape,
                                                               model.num_actions)
                            # ... and the environment/agent outputs are initialised afresh.
                            env_output = env.initial()
                            env_output = self.preprocess_env_output(env_output, model_flags)  # Pre-process the environment output.
                            env_output['frame'] = Utils.padding_state(env_output['frame'],
                                                                      self.max_observation_space.shape)
                            agent_state = model.initial_state(batch_size=1)
                            agent_output, unused_state = model(env_output, task_flags.action_space_id, agent_state)
                        # (debug) log the size of the current action space:
                        # self.action_spaces[task_flags.action_space_id].n

                # Take the next buffer to fill; None is the signal to stop.
                index = free_queue.get()
                if index is None:
                    break

                # Store the current environment output, agent output and agent state in slot 0 of the buffer.
                for key in env_output:
                    buffers[key][index][0, ...] = env_output[key]
                for key in agent_output:
                    buffers[key][index][0, ...] = agent_output[key]
                for i, tensor in enumerate(agent_state):
                    initial_agent_state_buffers[index][i][...] = tensor

                # Unroll for unroll_length steps.
                for t in range(self._unroll_length):
                    timings.reset()

                    with torch.no_grad():
                        # Compute the agent output: a uniformly random action while actor_is_random is set,
                        # the model's output otherwise.
                        if self.actor_is_random:
                            agent_output, unused_state = {
                                "actual_action": torch.randint(0, self.action_spaces[task_flags.action_space_id].n,
                                                               (1,))}, ()
                        else:
                            agent_output, agent_state = model(env_output, task_flags.action_space_id, agent_state)

                    timings.time("model")

                    if "actual_action" in agent_output:
                        # If the agent output carries an actual action, step the environment with that one.
                        env_output = env.step(agent_output["actual_action"])
                    else:
                        env_output = env.step(agent_output["action"])  # Execute the action.
                    env_output = self.preprocess_env_output(env_output, model_flags, agent_output)  # Pre-process the environment output.

                    # Pad the observation to the common observation shape.
                    env_output['frame'] = Utils.padding_state(env_output['frame'], self.max_observation_space.shape)

                    timings.time("step")

                    # Store the outputs of this step in slot t + 1 of the buffer.
                    for key in env_output:
                        buffers[key][index][t + 1, ...] = env_output[key]
                    for key in agent_output:
                        buffers[key][index][t + 1, ...] = agent_output[key]

                    # Actor 0 additionally keeps the episode for video logging.
                    if actor_index == 0:
                        if env_output['done'].squeeze():
                            # Drop whatever is still queued so that the put below stores the newest episode.
                            try:
                                self._videos_to_log.get(timeout=1)
                            except queue.Empty:
                                pass
                            except (FileNotFoundError, ConnectionRefusedError, ConnectionResetError, RuntimeError) as e:
                                # Sometimes it seems like the videos_to_log socket fails. Since video logging is not
                                # mission-critical, just let it go.
                                self.logger.warning(
                                    f"Video logging socket seems to have failed with error {e}. Aborting video log.")
                                pass

                            # Publish a copy of the observations collected for rendering, then start over.
                            self._videos_to_log.put(copy.deepcopy(observations_to_render))
                            observations_to_render.clear()

                        # Keep this step's observation for rendering (actor 0 only).
                        observations_to_render.append(env_output['frame'].squeeze(0).squeeze(0)[-1])

                    timings.time("write")

                # Per-key view of just this buffer index, handed to the hook for any extra processing.
                new_buffers = {key: buffers[key][index] for key in buffers.keys()}
                self.on_act_unroll_complete(task_flags, actor_index, agent_output, env_output, new_buffers)
                full_queue.put(index)

            if actor_index == 0:
                # Only actor 0 logs the timing summary.
                self.logger.info("Actor %i: %s", actor_index, timings.summary())

        except KeyboardInterrupt:
            pass  # Return silently.
        except Exception as e:
            self.logger.error(f"Exception in worker process {actor_index}: {e}")
            traceback.print_exc()
            print()
            raise e
        finally:
            self.logger.info(f"Finalizing actor {actor_index}")
            if env is not None:
                env.close()

    def get_batch(self, flags, free_queue: py_mp.Queue, full_queue: py_mp.Queue, buffers: Buffers,
                  initial_agent_state_buffers, timings, lock):
        """
        Build one batch of ``flags.batch_size`` rollouts from buffers the actors have filled.

        :param flags: configuration; ``batch_size`` and ``device`` are read.
        :param free_queue: queue that the consumed buffer indices are handed back to.
        :param full_queue: queue of indices of filled buffers.
        :param buffers: mapping from key to the list of per-index rollout tensors.
        :param initial_agent_state_buffers: per-index tuples of initial agent state tensors.
        :param timings: object whose ``time(name)`` is called after every stage.
        :param lock: lock held while the indices are dequeued.
        :return: (batch, initial_agent_state): a dict of tensors stacked along dim 1 and a
                 tuple of state tensors concatenated along dim 1, both moved to ``flags.device``.
        """
        with lock:
            # Holding the lock, take batch_size ready indices off the full queue.
            timings.time("lock")
            indices = [full_queue.get() for _ in range(flags.batch_size)]
            timings.time("dequeue")
        batch = {
            # Stack the selected buffers of every key into one batch tensor (this copies the data).
            key: torch.stack([buffers[key][m] for m in indices], dim=1) for key in buffers
        }
        # Concatenate the initial states right away. A lazy generator would only read the
        # state buffers after their indices went back to free_queue below, by which time
        # act() may already be overwriting them in place for the next rollout.
        initial_agent_state = tuple(
            torch.cat(ts, dim=1)
            for ts in zip(*[initial_agent_state_buffers[m] for m in indices])
        )
        timings.time("batch")
        for m in indices:
            # Everything has been copied out, so the buffers can be reused.
            free_queue.put(m)
        timings.time("enqueue")

        batch = {k: t.to(device=flags.device, non_blocking=True) for k, t in batch.items()}
        initial_agent_state = tuple(
            t.to(device=flags.device, non_blocking=True) for t in initial_agent_state
        )
        timings.time("device")
        return batch, initial_agent_state

    def get_policy_logits(self, batch_output):
        """Return the policy logits, i.e. the entry stored under "policy_logits" in ``batch_output``."""
        policy_logits = batch_output["policy_logits"]
        return policy_logits

    def compute_loss(self, model_flags, task_flags, learner_model, batch, initial_agent_state, with_custom_loss=True):
        """
        Compute the loss for one batch: run the learner model on the batch, derive
        V-trace returns from its outputs and combine the policy-gradient, baseline and
        entropy terms (plus the custom loss, if requested).

        :param model_flags: model configuration (reward normalisation/clipping, discounting, loss weights).
        :param task_flags: task configuration; ``action_space_id`` is passed to the model.
        :param learner_model: the model being trained.
        :param batch: dict of batched rollout tensors.
        :param initial_agent_state: agent state at the start of the rollouts.
        :param with_custom_loss: whether custom_loss() is evaluated and added to the total loss.
        :return: (total_loss, stats, pg_loss, baseline_loss), where ``stats`` is a dict of
                 Python scalars (plus whatever custom_loss() reports).
        :raises ValueError: if ``model_flags.reward_clipping`` is neither "abs_one" nor "none".
        """
        # Note that action_space_id is not really used here: it is needed to generate actions,
        # but the loss uses the actions that were already computed and executed.
        learner_outputs, unused_state = learner_model(batch, task_flags.action_space_id, initial_agent_state)

        # The last slice of the value function is used for bootstrapping.
        bootstrap_value = learner_outputs["baseline"][-1]

        # Move from obs[t] -> action[t] to action[t] -> obs[t].
        batch = {key: tensor[1:] for key, tensor in batch.items()}
        learner_outputs = {key: tensor[:-1] for key, tensor in learner_outputs.items()}

        rewards = batch["reward"]

        # Optional reward normalisation by the running standard deviation kept on the model.
        if model_flags.normalize_reward:
            learner_model.update_running_moments(rewards)
            # NOTE(review): this divides in place, and `rewards` is a slice of the reward tensor
            # that was passed in - presumably the caller's tensor changes too; confirm intended.
            rewards /= learner_model.get_running_std()

        # Reward clipping.
        if model_flags.reward_clipping == "abs_one":
            # Limit the rewards to the range [-1, 1].
            clipped_rewards = torch.clamp(rewards, -1, 1)
        elif model_flags.reward_clipping == "none":
            clipped_rewards = rewards
        else:
            # Without this branch an unknown setting left clipped_rewards unbound and
            # surfaced as an unrelated UnboundLocalError further down.
            raise ValueError(f"Unsupported reward_clipping type {model_flags.reward_clipping}.")

        discounts = (~batch["done"]).float() * model_flags.discounting

        # If the batch records the action that was actually executed, the loss uses that one.
        actions = batch["actual_action"] if "actual_action" in batch else batch["action"]

        # Compute the returns, then the individual loss terms.
        vtrace_returns = vtrace.from_logits(
            behavior_policy_logits=self.get_policy_logits(batch),
            target_policy_logits=self.get_policy_logits(learner_outputs),
            actions=actions,
            discounts=discounts,
            rewards=clipped_rewards,
            values=learner_outputs["baseline"],
            bootstrap_value=bootstrap_value,
        )

        pg_loss = self.compute_policy_gradient_loss(
            self.get_policy_logits(learner_outputs),
            actions,
            vtrace_returns.pg_advantages,
        )
        baseline_loss = model_flags.baseline_cost * self.compute_baseline_loss(
            vtrace_returns.vs - learner_outputs["baseline"]
        )
        entropy_loss = model_flags.entropy_cost * self.compute_entropy_loss(
            self.get_policy_logits(learner_outputs)
        )

        total_loss = pg_loss + baseline_loss + entropy_loss
        stats = {
            "pg_loss": pg_loss.item(),
            "baseline_loss": baseline_loss.item(),
            "entropy_loss": entropy_loss.item(),
            "vtrace_vs_mean": vtrace_returns.vs.mean().item(),
        }

        # Optionally add the custom loss on top of the V-trace loss.
        if with_custom_loss:
            custom_loss, custom_stats = self.custom_loss(task_flags, learner_model, initial_agent_state, batch,
                                                         vtrace_returns)
            total_loss += custom_loss
            stats.update(custom_stats)

        return total_loss, stats, pg_loss, baseline_loss

    def learn(self, model_flags, task_flags, actor_model, learner_model, batch, initial_agent_state, optimizer,
              scheduler, lock, ):
        """
        Perform one learning (optimization) step and copy the new weights to the actor model.

        :param model_flags: model configuration; ``grad_norm_clipping`` is read here.
        :param task_flags: task configuration, forwarded to compute_loss().
        :param actor_model: model used by the actors; receives the learner's state dict afterwards.
        :param learner_model: model whose parameters are optimized.
        :param batch: dict of batched rollout tensors; must contain "done" and "episode_return".
        :param initial_agent_state: agent state at the start of the rollouts.
        :param optimizer: optimizer over the learner model's parameters.
        :param scheduler: learning-rate scheduler stepped after the optimizer, or None.
        :param lock: lock held for the whole step.
        :return: dict of statistics for this step.
        """
        with lock:
            # Only the real batch of new data is logged, not the version manipulated for training.
            # Just the two tensors the statistics need are snapshotted: deep-copying the whole
            # batch would duplicate every tensor in it (frames included) on each step.
            logged_done = batch["done"].clone()
            logged_episode_return = batch["episode_return"].clone()

            # Let subclasses adapt the batch that is actually trained on.
            batch = self.get_batch_for_training(batch)

            total_loss, stats, _, _ = self.compute_loss(model_flags, task_flags, learner_model, batch,
                                                        initial_agent_state)

            # With EpisodicLifeEnv (for Atari), episode_return can be nan: it stays nan until the game is over.
            batch_done_flags = logged_done * ~torch.isnan(logged_episode_return)
            episode_returns = logged_episode_return[batch_done_flags]
            stats.update({
                "episode_returns": tuple(episode_returns.cpu().numpy()),
                "mean_episode_return": torch.mean(episode_returns).item(),
                "total_loss": total_loss.item(),
            })

            optimizer.zero_grad()
            total_loss.backward()

            norm = nn.utils.clip_grad_norm_(learner_model.parameters(), model_flags.grad_norm_clipping)
            stats["total_norm"] = norm.item()

            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            # Bring the actor model up to date with the freshly updated learner parameters.
            actor_model.load_state_dict(learner_model.state_dict())
            return stats

    def create_buffer_specs(self, unroll_length, obs_shape, num_actions):
        """
        Describe the rollout buffers: for every key, the tensor size and dtype of one buffer.

        Every buffer holds ``unroll_length + 1`` steps along its first dimension.

        :param unroll_length: number of steps in one unroll.
        :param obs_shape: shape of a single observation.
        :param num_actions: number of entries in the policy logits.
        :return: dict mapping key to ``dict(size=..., dtype=...)``.
        """
        num_steps = unroll_length + 1
        # (key, per-step shape, dtype) of every buffer, in the order the keys are stored.
        layout = (
            ("frame", (*obs_shape,), torch.uint8),
            ("reward", (), torch.float32),
            ("done", (), torch.bool),
            ("episode_return", (), torch.float32),
            ("episode_step", (), torch.int32),
            ("policy_logits", (num_actions,), torch.float32),
            ("baseline", (), torch.float32),
            ("uncertainty", (), torch.float32),
            ("last_action", (), torch.int64),
            ("action", (), torch.int64),
        )
        return {
            key: {"size": (num_steps, *step_shape), "dtype": dtype}
            for key, step_shape, dtype in layout
        }

    def create_buffers(self, flags, obs_shape, num_actions) -> Buffers:
        """
        Allocate the rollout buffers: ``flags.num_buffers`` tensors per key of
        create_buffer_specs(), each one placed in shared memory.

        The tensors come from torch.empty, so their contents are undefined until written.

        :param flags: configuration; ``unroll_length`` and ``num_buffers`` are read.
        :param obs_shape: shape of a single observation.
        :param num_actions: number of entries in the policy logits.
        :return: dict mapping every key to its list of buffers.
        """
        specs = self.create_buffer_specs(flags.unroll_length, obs_shape, num_actions)
        return {
            key: [torch.empty(**spec).share_memory_() for _ in range(flags.num_buffers)]
            for key, spec in specs.items()
        }

    def create_learn_threads(self, batch_and_learn, stats_lock, thread_free_queue, thread_full_queue):
        """
        Create and start ``num_learner_threads`` learner threads.

        Every thread runs ``batch_and_learn(index, stats_lock, thread_state, batch_lock,
        learn_lock, thread_free_queue, thread_full_queue)``; the batch lock and the learn
        lock are shared by all threads, the thread state is individual.

        :return: (list of started threads, list of their LearnerThreadState objects)
        """
        num_threads = self._model_flags.num_learner_threads
        learner_thread_states = [LearnerThreadState() for _ in range(num_threads)]
        batch_lock, learn_lock = threading.Lock(), threading.Lock()

        threads = []
        for thread_index, thread_state in enumerate(learner_thread_states):
            worker_args = (thread_index, stats_lock, thread_state, batch_lock, learn_lock,
                           thread_free_queue, thread_full_queue)
            worker = threading.Thread(
                target=batch_and_learn, name="batch-and-learn-%d" % thread_index, args=worker_args)
            worker.start()
            threads.append(worker)
        return threads, learner_thread_states

    def cleanup(self):
        """
        Finish the current task: reset the returned-timestep counter, mark that no train
        loop is running and shut down the parallel workers.
        """
        self.logger.info("正在完成任务，将timestep_return设置为0")

        # The task is finished, so its counter starts from zero again; clearing the running
        # id makes sure that the training loop ends.
        self.last_timestep_returned, self._train_loop_id_running = 0, None

        self._cleanup_parallel_workers()

    def _cleanup_parallel_workers(self):
        """
        End all parallel actor processes and ask all learner threads to stop.

        Actors first get the chance to end on their own (a ``None`` per actor on the free
        queue); an actor that is still alive after the wait is terminated, and killed if
        even that does not end it. Learner threads are only asked to stop, by setting
        their state to STOP_REQUESTED; this method does not wait for them.
        """
        self.logger.info("清理Actors")

        # Queue one end signal per actor and resume each actor process so it can receive it.
        for actor in self._actor_processes:
            self.free_queue.put(None)
            try:
                actor_process = psutil.Process(actor.pid)
                actor_process.resume()
            except (psutil.NoSuchProcess, psutil.AccessDenied, ValueError):
                # If it's already dead, just let it go
                pass

        # Try to wait for the actors to end cleanly. If they do not, force them to.
        for actor_index, actor in enumerate(self._actor_processes):
            try:
                actor.join(30)  # Give up on waiting eventually

                if actor.exitcode is None:
                    # terminate() only sends the signal; join so that the process is reaped
                    # and close() below does not fail on a process that is still running.
                    actor.terminate()
                    actor.join(5)

                if actor.exitcode is None:
                    # act() installs a SIGTERM handler that closes the env without exiting,
                    # so an actor can survive terminate(); kill() cannot be handled.
                    actor.kill()
                    actor.join(5)

                actor.close()
                self.logger.info(f"[Actor {actor_index}] Cleanup complete")
            except ValueError:  # if actor already killed
                pass
            except AttributeError:  # ForkProcess doesn't have close()
                pass

        # Stop the learners, so they don't keep producing results once we are done (or something failed).
        self.logger.info("清理Learners")
        for thread_state in self._learner_thread_states:
            # The learner thread checks and updates its state while holding this lock; writing
            # without it could let the thread overwrite the request with RUNNING.
            with thread_state.lock:
                thread_state.state = LearnerThreadState.STOP_REQUESTED

        self.logger.info("完成所有并行Workers的清理")

    def resume_actor_processes(self, ctx, task_flags, actor_processes, free_queue, full_queue,
                               initial_agent_state_buffers):
        """
        Resume every actor process and replace, in place, those that cannot be resumed.

        An actor is recreated when its process cannot be looked up or resumed, is no longer
        running, or is in a status other than running/sleeping/disk-sleep.

        :param ctx: multiprocessing context used to create replacement processes.
        :param task_flags: task configuration passed to the replacement actors.
        :param actor_processes: list of actor processes; replaced entries are written back into it.
        :param free_queue: queue of free buffer indices for the replacement actors.
        :param full_queue: queue of filled buffer indices for the replacement actors.
        :param initial_agent_state_buffers: initial agent state storage for the replacement actors.
        """
        allowed_statuses = ["running", "sleeping", "disk-sleep"]

        # Copy, so iterator and what's being updated are separate
        actor_processes_copy = actor_processes.copy()
        for actor_index, actor in enumerate(actor_processes_copy):
            actor_pid = None  # actor.pid fails with ValueError if the process is already closed

            try:
                actor_pid = actor.pid
                actor_process = psutil.Process(actor_pid)
                actor_process.resume()
                recreate_actor = not actor_process.is_running() or actor_process.status() not in allowed_statuses
            except (psutil.NoSuchProcess, psutil.AccessDenied, ValueError):
                # Logger.warn is deprecated; warning() is what act() uses on this logger as well.
                self.logger.warning(
                    f"Actor with pid {actor_pid} in actor index {actor_index} was unable to be restarted. Recreating...")
                recreate_actor = True

            if recreate_actor:
                # Kill the original ctx.Process object, rather than the one attached to by pid
                # Attempting to fix an issue where the actor processes are hanging, CPU util shows zero
                try:
                    actor_processes[actor_index].kill()
                    actor_processes[actor_index].join()
                    actor_processes[actor_index].close()
                except ValueError:  # if actor already killed
                    pass

                self.logger.warning(
                    f"Actor actor index {actor_index} was unable to be restarted. Recreating...")
                new_actor = ctx.Process(
                    target=self.act,
                    args=(
                        self._model_flags,
                        task_flags,
                        actor_index,
                        free_queue,
                        full_queue,
                        self.actor_model,
                        self.buffers,
                        initial_agent_state_buffers,
                    ),
                )
                new_actor.start()
                actor_processes[actor_index] = new_actor

    def save(self, output_path):
        """
        Write a checkpoint ("model.tar") and the trainer metadata ("impala_metadata.json")
        into ``output_path``.

        An existing "model.tar" is first copied to "model_bak.tar". Nothing at all is
        written when checkpointing is disabled in the model flags.

        :param output_path: directory that receives the files.
        """
        if self._model_flags.disable_checkpoint:
            return

        model_file_path = os.path.join(output_path, "model.tar")
        backup_file_path = os.path.join(output_path, "model_bak.tar")

        # Back up previous model (sometimes they can get corrupted)
        if os.path.exists(model_file_path):
            shutil.copyfile(model_file_path, backup_file_path)

        # Save the model
        self.logger.info(f"Saving model to {output_path}")

        checkpoint_data = {}
        checkpoint_data["model_state_dict"] = self.actor_model.state_dict()
        checkpoint_data["optimizer_state_dict"] = self.optimizer.state_dict()
        # The scheduler only exists once train() has created it.
        if self._scheduler is not None:
            checkpoint_data["scheduler_state_dict"] = self._scheduler.state_dict()

        torch.save(checkpoint_data, model_file_path)

        # Save metadata
        metadata_path = os.path.join(output_path, "impala_metadata.json")
        with open(metadata_path, "w+") as metadata_file:
            json.dump({"last_timestep_returned": self.last_timestep_returned}, metadata_file)

    def load(self, output_path):
        """
        Restore the models, the optimizer and the trainer metadata from ``output_path``.

        If "model.tar" cannot be read because the archive is corrupted, "model_bak.tar" is
        loaded instead. The scheduler state, when schedulers are in use, is only stored on
        the instance here. A missing checkpoint or metadata file is not an error.

        :param output_path: directory previously written by save().
        :raises RuntimeError: if reading the checkpoint fails for a reason other than corruption.
        """
        model_file_path = os.path.join(output_path, "model.tar")
        if os.path.exists(model_file_path):
            self.logger.info(f"Loading model from {output_path}")
            try:
                checkpoint = torch.load(model_file_path, map_location="cpu")
            except RuntimeError as e:
                # Only a corrupted archive is recoverable from the backup. Any other error is
                # re-raised as it is (an assert here would be stripped under -O and would
                # otherwise replace the real error by an AssertionError).
                if "PytorchStreamReader" not in str(e):
                    raise
                self.logger.warning("Save file corrupted, resuming from backup. Likely the run ended during model save.")
                model_file_path = os.path.join(output_path, "model_bak.tar")
                checkpoint = torch.load(model_file_path, map_location="cpu")

            self.actor_model.load_state_dict(checkpoint["model_state_dict"])
            self.learner_model.load_state_dict(checkpoint["model_state_dict"])
            self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

            if self._model_flags.use_scheduler:
                self._scheduler_state_dict = checkpoint.get("scheduler_state_dict", None)

                if self._scheduler_state_dict is None:
                    # Tracked by issue #109
                    self.logger.warning("No scheduler state dict found to load when one was expected.")
        else:
            self.logger.info("No model to load, starting from scratch")

        # Load metadata
        metadata_path = os.path.join(output_path, "impala_metadata.json")
        if os.path.exists(metadata_path):
            self.logger.info(f"Loading impala metdata from {metadata_path}")
            with open(metadata_path, "r") as metadata_file:
                metadata = json.load(metadata_file)

            self.last_timestep_returned = metadata["last_timestep_returned"]

    def train(self, task_flags):  # pylint: disable=too-many-branches, too-many-statements
        T = self._model_flags.unroll_length
        B = self._model_flags.batch_size

        def lr_lambda(epoch):
            return 1 - min(epoch * T * B, task_flags.total_steps) / task_flags.total_steps

        if self._model_flags.use_scheduler:
            self._scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)
        else:
            self._scheduler = None

        if self._scheduler is not None and self._scheduler_state_dict is not None:
            self.logger.info("Loading scheduler state dict")
            self._scheduler.load_state_dict(self._scheduler_state_dict)
            self._scheduler_state_dict = None

        # Add initial RNN state.
        initial_agent_state_buffers = []
        for _ in range(self._model_flags.num_buffers):
            state = self.actor_model.initial_state(batch_size=1)
            for t in state:
                t.share_memory_()
            initial_agent_state_buffers.append(state)

        # * ---第一步：设置演员进程并启动它们---
        self._actor_processes = []
        ctx = mp.get_context("fork")

        # See: https://stackoverflow.com/questions/47085458/why-is-multiprocessing-queue-get-so-slow for why Manager
        self.free_queue = py_mp.Manager().Queue()
        self.full_queue = py_mp.Manager().Queue()

        for i in range(self._model_flags.num_actors):
            # 每个Actor执行act方法
            actor = ctx.Process(
                target=self.act,
                args=(
                    self._model_flags,
                    task_flags,
                    i,
                    self.free_queue,
                    self.full_queue,
                    self.actor_model,
                    self.buffers,
                    initial_agent_state_buffers,
                ),
            )
            actor.start()
            self._actor_processes.append(actor)

        stat_keys = [
            "total_loss",
            "mean_episode_return",
            "pg_loss",
            "baseline_loss",
            "entropy_loss",
            "vtrace_vs_mean"
        ]
        self.logger.info("# Step\t%s", "\t".join(stat_keys))

        step, collected_stats = self.last_timestep_returned, {}
        self._stats_lock = threading.Lock()

        # * ---第二步：定义学习器线程函数并启动多线程---
        def batch_and_learn(i, lock, thread_state, batch_lock, learn_lock, thread_free_queue, thread_full_queue):
            """
            Learner worker loop handed to ``self.create_learn_threads``.

            Each iteration fetches a batch with ``self.get_batch``, runs ``self.learn`` on it, and then,
            while holding ``lock``, advances the shared ``step`` counter by ``T * B``, logs the selected
            ``stat_keys`` through ``self.plogger`` and appends every returned stat to ``collected_stats``.

            Args:
                i: Worker index; only consulted by the trailing timing log.
                lock: Guards the enclosing scope's ``step`` and ``collected_stats``.
                thread_state: Object exposing ``lock`` and ``state``; checked at the top of every iteration
                    so that a STOP_REQUESTED state ends the loop.
                batch_lock: Forwarded to ``self.get_batch``.
                learn_lock: Forwarded to ``self.learn``.
                thread_free_queue: Forwarded to ``self.get_batch``.
                thread_full_queue: Forwarded to ``self.get_batch``.

            Raises:
                Exception: Anything raised inside the loop is logged and re-raised. ``thread_state.state``
                    is not updated on that path.

            NOTE(review): as written, the statements after the try/except appear unreachable -- the loop
            has no ``break``, so it is left only through ``return`` or a re-raised exception. Confirm whether
            the timing summary log is still wanted.
            """

            try:
                # `step` and `collected_stats` live in the enclosing function and are shared by all workers.
                nonlocal step, collected_stats
                timings = prof.Timings()

                while True:
                    # If a stop has been requested, mark this worker as stopped and end the thread.
                    with thread_state.lock:
                        if thread_state.state == LearnerThreadState.STOP_REQUESTED:
                            thread_state.state = LearnerThreadState.STOPPED
                            return

                        thread_state.state = LearnerThreadState.RUNNING

                    timings.reset()
                    batch, agent_state = self.get_batch(
                        self._model_flags,
                        thread_free_queue,
                        thread_full_queue,
                        self.buffers,
                        initial_agent_state_buffers,
                        timings,
                        batch_lock,
                    )
                    stats = self.learn(
                        self._model_flags, task_flags, self.actor_model, self.learner_model, batch, agent_state,
                        self.optimizer, self._scheduler, learn_lock
                    )
                    timings.time("learn")
                    with lock:
                        # T and B come from the enclosing scope (B is the batch size; T is defined
                        # outside this function -- presumably the unroll length, TODO confirm).
                        step += T * B
                        to_log = dict(step=step)
                        to_log.update({k: stats[k] for k in stat_keys if k in stats})
                        self.plogger.info(to_log)

                        # We might collect stats more often than we return them to the caller, so collect them all
                        for key in stats.keys():
                            if key not in collected_stats:
                                collected_stats[key] = []

                            # Sequence-valued stats are flattened into the list; scalars are appended.
                            if isinstance(stats[key], tuple) or isinstance(stats[key], list):
                                collected_stats[key].extend(stats[key])
                            else:
                                collected_stats[key].append(stats[key])
            except Exception as e:
                # Log the failure, then propagate it to whoever runs this worker.
                self.logger.error(f"Learner线程失败，出现异常{e}")
                raise e

            # Only worker 0 reports the timing summary, so it is logged once rather than per worker.
            if i == 0:
                self.logger.info("Batch and learn: %s", timings.summary())

            thread_state.state = LearnerThreadState.STOPPED

        for m in range(self._model_flags.num_buffers):
            # 初始化free_queue为每个缓冲区的索引
            self.free_queue.put(m)

        self.explore()  # 学习之前先进行探索

        # 创建学习器线程
        threads, self._learner_thread_states = self.create_learn_threads(batch_and_learn, self._stats_lock,
                                                                         self.free_queue, self.full_queue)

        # * ---第三步：启动训练循环---
        # 为训练循环创建ID,并且只有当它是活动ID时才循环
        assert self._train_loop_id_running is None, "试图在另一个训练循环处于活动状态时启动一个训练循环。"
        train_loop_id = self._train_loop_id_counter
        self._train_loop_id_counter += 1
        self._train_loop_id_running = train_loop_id
        self.logger.info(f"启动训练循环：ID {train_loop_id}")

        timer = timeit.default_timer
        try:
            while self._train_loop_id_running == train_loop_id:
                start_step = step
                start_time = timer()
                time.sleep(self._model_flags.seconds_between_yields)

                # Copy right away, because there's a race where stats can get re-set and then certain things set below will be missing (eg "step")
                with self._stats_lock:
                    # 复制收集的信息并清空
                    stats_to_return = copy.deepcopy(collected_stats)
                    collected_stats.clear()

                sps = (step - start_step) / (timer() - start_time)

                # Aggregate our collected values. Do it with mean so it's not sensitive to the number of times
                # learning occurred in the interim
                mean_return = np.array(stats_to_return.get("episode_returns", [np.nan])).mean()
                stats_to_return["mean_episode_return"] = mean_return

                # Make a copy of the keys so we're not updating it as we iterate over it
                for key in list(stats_to_return.keys()).copy():
                    if key.endswith("loss") or key == "total_norm":
                        # Replace with the number we collected and the mean value, otherwise the logs are very verbose
                        stats_to_return[f"{key}_count"] = len(np.array(stats_to_return.get(key, [])))
                        stats_to_return[key] = np.array(stats_to_return.get(key, [np.nan])).mean()

                self.logger.info(
                    "Steps %i @ %.1f SPS. Mean return %f. Stats:\n%s",
                    step,
                    sps,
                    mean_return,
                    pprint.pformat(stats_to_return),
                )
                stats_to_return["step"] = step
                stats_to_return["step_delta"] = step - self.last_timestep_returned

                try:
                    video = self._videos_to_log.get(block=False)  # 获取actor得到的观测video
                    stats_to_return["video"] = video
                except queue.Empty:
                    pass
                except (FileNotFoundError, ConnectionRefusedError, ConnectionResetError, RuntimeError) as e:
                    # Sometimes it seems like the videos_to_log socket fails. Since video logging is not
                    # mission-critical, just let it go.
                    self.logger.warning(
                        f"Video logging socket seems to have failed with error {e}. Aborting video log.")
                    pass

                # This block sets us up to yield our results in batches, pausing everything while yielded.
                if self.last_timestep_returned != step:
                    self.last_timestep_returned = step

                    # Stop learn threads, they are recreated after yielding.
                    # Do this before the actors in case we need to do a last batch
                    self.logger.info("Stopping learners")
                    for thread_id, thread_state in enumerate(self._learner_thread_states):
                        wait = False
                        with thread_state.lock:
                            if thread_state.state != LearnerThreadState.STOPPED and threads[thread_id].is_alive():
                                thread_state.state = LearnerThreadState.STOP_REQUESTED
                                wait = True

                        # Wait for it to stop, otherwise we have training overlapping with eval, and possibly
                        # the thread creation below
                        if wait:
                            thread_state.wait_for([LearnerThreadState.STOPPED], timeout=30)

                    # The actors will keep going unless we pause them, so...do that.
                    if self._model_flags.pause_actors_during_yield:
                        for actor in self._actor_processes:
                            psutil.Process(actor.pid).suspend()

                    # Make sure the queue is empty (otherwise things can get dropped in the shuffle)
                    # (Not 100% sure relevant but:) https://stackoverflow.com/questions/19257375/python-multiprocessing-queue-put-not-working-for-semi-large-data
                    while not self.free_queue.empty():
                        try:
                            self.free_queue.get(block=False)
                        except queue.Empty:
                            # Race between empty check and get, I guess
                            break

                    while not self.full_queue.empty():
                        try:
                            self.full_queue.get(block=False)
                        except queue.Empty:
                            # Race between empty check and get, I guess
                            break

                    yield stats_to_return  # 返回当前训练循环的结果

                    # Ensure everything is set back up to train
                    self.actor_model.train()
                    self.learner_model.train()

                    # Resume the actors. If one is dead, replace it with a new one
                    if self._model_flags.pause_actors_during_yield:
                        self.resume_actor_processes(ctx, task_flags, self._actor_processes, self.free_queue,
                                                    self.full_queue,
                                                    initial_agent_state_buffers)

                    # Resume the learners by creating new ones
                    self.logger.info("重新启动Learners")
                    threads, self._learner_thread_states = self.create_learn_threads(batch_and_learn, self._stats_lock,
                                                                                     self.free_queue, self.full_queue)
                    self.logger.info("重新启动完成")

                    for m in range(self._model_flags.num_buffers):
                        # 重新初始化free_queue
                        self.free_queue.put(m)
                    self.logger.info("Free queue 重新填充")

        except KeyboardInterrupt:
            pass

        finally:
            self._cleanup_parallel_workers()
            for thread in threads:
                thread.join()
            self.logger.info("学习在%d步后完成。", step)

    @staticmethod
    def _collect_test_episode(pickled_args):
        """
        Run a single evaluation episode to completion and report its length and return.

        Args:
            pickled_args: cloudpickle-serialized tuple
                ``(task_flags, logger, model, observation_shape, preprocess_env_output)``.
                Passing one bytes blob lets the same payload be used both in-process and
                through a worker pool. NOTE: this is unpickled as-is, so it must only ever
                come from trusted code (``test`` builds it); never feed it external data.

        Returns:
            ``(step, returns)``, or ``(step, returns, render_images)`` when
            ``task_flags.mode == "test_render"``. ``step`` is the number of environment steps
            taken, ``returns`` holds the episode return recorded when the episode really ended,
            and ``render_images`` holds one ``env.gym_env.render()`` result per step.
        """
        task_flags, logger, model, observation_shape, preprocess_env_output = cloudpickle.loads(pickled_args)
        render_mode = task_flags.mode == "test_render"

        # Render mode builds the environment from its own spec; otherwise use the regular one.
        env_spec = task_flags.render_env_spec if render_mode else task_flags.env_spec
        gym_env, seed = Utils.make_env(env_spec, create_seed=True)
        logger.info(f"Environment and libraries setup with seed {seed}")
        env = environment.Environment(gym_env)

        done = False
        step = 0
        returns = []
        render_images = []

        # Close the environment even if the model or the environment raises mid-episode,
        # so a failed evaluation does not leak it.
        try:
            env_output = env.initial()
            env_output = preprocess_env_output(env_output, model._model_flags)

            # Bring the frame to the shared observation shape.
            env_output['frame'] = Utils.padding_state(env_output['frame'], observation_shape)

            while not done:
                if render_mode:
                    # Ask the underlying environment for a rendered image of the current state.
                    render_images.append(env.gym_env.render())
                agent_outputs = model(env_output, task_flags.action_space_id)
                policy_outputs, _ = agent_outputs

                # Prefer "actual_action" when the policy provides it, otherwise fall back to "action".
                if "actual_action" in policy_outputs:
                    env_output = env.step(policy_outputs["actual_action"])
                else:
                    env_output = env.step(policy_outputs["action"])
                env_output = preprocess_env_output(env_output, model._model_flags, policy_outputs)

                # Bring the frame to the shared observation shape.
                env_output['frame'] = Utils.padding_state(env_output['frame'], observation_shape)
                step += 1

                # NaN if the done was "fake" (e.g. Atari). We want real scores here so wait for the real return.
                done = env_output["done"].item() and not torch.isnan(env_output["episode_return"])

                if done:
                    returns.append(env_output["episode_return"].item())
                    logger.info(
                        "Episode ended after %d steps. Return: %.1f",
                        env_output["episode_step"].item(),
                        env_output["episode_return"].item(),
                    )
        finally:
            env.close()

        if render_mode:
            return step, returns, render_images
        return step, returns

    def test(self, task_flags, num_episodes: int = 10):
        """
        Evaluate the actor model for ``num_episodes`` episodes and yield the collected stats.

        In ``"test_render"`` mode the episodes run one after another in this process and the
        rendered images of all episodes are returned. Otherwise the episodes run in batches of at
        most ``self._model_flags.eval_episode_num_parallel`` worker processes.

        Args:
            task_flags: Task configuration; ``task_flags.mode`` selects the render or parallel path
                and the whole object is forwarded to ``_collect_test_episode``.
            num_episodes: Number of episodes to run.

        Yields:
            A single dict with ``"episode_returns"``, ``"step"`` (total environment steps) and
            ``"num_episodes"``; in render mode also ``"render_images"``.
        """
        if not self._model_flags.no_eval_mode:
            self.actor_model.eval()

        returns = []
        render_images = []
        step = 0

        # Every episode receives the same arguments, so serialize them (model included) only once.
        pickled_args = cloudpickle.dumps((task_flags, self.logger, self.actor_model,
                                          self.max_observation_space.shape, self.preprocess_env_output))

        if task_flags.mode == "test_render":
            # Rendering does not need parallelism; run the episodes sequentially in-process.
            for _ in range(num_episodes):
                # Keep the per-episode images in their own name: unpacking into `render_images`
                # would discard the frames accumulated from earlier episodes.
                episode_step, episode_returns, episode_images = self._collect_test_episode(pickled_args)
                step += episode_step
                returns.extend(episode_returns)
                render_images.extend(episode_images)
            stats = {"episode_returns": returns, "step": step, "num_episodes": len(returns),
                     "render_images": render_images}
        else:
            # Split the episodes into batches of at most `eval_episode_num_parallel` that run concurrently.
            num_parallel = self._model_flags.eval_episode_num_parallel
            for batch_start_id in range(0, num_episodes, num_parallel):
                # If we are in the last batch, only do the necessary number, otherwise do the max num in parallel
                batch_num_episodes = min(num_episodes - batch_start_id, num_parallel)

                with Pool(processes=batch_num_episodes) as pool:
                    async_objs = [
                        pool.apply_async(self._collect_test_episode, (pickled_args,))
                        for _ in range(batch_num_episodes)
                    ]

                    for async_obj in async_objs:
                        episode_step, episode_returns = async_obj.get()[:2]
                        step += episode_step
                        returns.extend(episode_returns)

            stats = {"episode_returns": returns, "step": step, "num_episodes": len(returns)}

        # Guard the average so that an evaluation with no collected episodes logs NaN instead of
        # raising ZeroDivisionError.
        mean_return = sum(returns) / len(returns) if returns else float("nan")
        self.logger.info(
            "Average returns over %i episodes: %.1f", len(returns), mean_return
        )
        yield stats

    def explore(self):
        """
        No-op exploration hook.

        The training generator in this class calls this once, after the free queue has first been
        filled and before the learner threads are created. This implementation does nothing;
        presumably it exists so a subclass can override it (confirm against subclasses).
        """
        return None
